{
 "cells": [
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parameter import Parameter\n",
    "from dgl.nn.pytorch import SAGEConv\n",
    "from dgl.nn.pytorch import GraphConv\n",
    "import numpy as np\n",
    "from layers import MultiHeadGATLayer, HAN_metapath_specific"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:55:20.436720Z",
     "start_time": "2024-05-08T00:55:16.780301400Z"
    }
   },
   "id": "a03b107709c7116a",
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Semantic-level Attention with metapath length 2"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5df8e432c8e1d01d"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class SemanticAttention(nn.Module):\n",
    "    \"\"\"Semantic-level attention that fuses per-metapath node embeddings.\n",
    "\n",
    "    A small scoring network rates every metapath embedding of every node. The\n",
    "    scores are averaged over nodes and passed through a sigmoid, which yields\n",
    "    one gate per metapath that is shared by all nodes.\n",
    "\n",
    "    Args:\n",
    "        in_size: Size of the last dimension of the embeddings given to forward.\n",
    "        hidden_size: Width of the hidden layer of the scoring network.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_size, hidden_size=128):\n",
    "        super().__init__()\n",
    "\n",
    "        # Linear -> Tanh -> Linear(1): one scalar score per (node, metapath).\n",
    "        scorer = [\n",
    "            nn.Linear(in_size, hidden_size),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(hidden_size, 1, bias=False),\n",
    "        ]\n",
    "        self.project = nn.Sequential(*scorer)\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Gate each metapath embedding and sum over the metapath axis.\n",
    "\n",
    "        Args:\n",
    "            z: Stacked embeddings indexed as (node, metapath, feature).\n",
    "\n",
    "        Returns:\n",
    "            Tensor indexed as (node, feature): the gated sum over metapaths.\n",
    "        \"\"\"\n",
    "        # Score every (node, metapath) pair, then average over nodes (divide by |V|)\n",
    "        # to obtain a single importance value per metapath.\n",
    "        path_importance = self.project(z).mean(dim=0)\n",
    "        path_gate = torch.sigmoid(path_importance)\n",
    "        # A leading singleton axis lets the per-metapath gate broadcast over all nodes.\n",
    "        gated = path_gate.unsqueeze(0) * z\n",
    "        return gated.sum(dim=1)\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:55:20.460718700Z",
     "start_time": "2024-05-08T00:55:20.440700900Z"
    }
   },
   "id": "34aac36e1bdc2eb9",
   "execution_count": 2
  },
  {
   "cell_type": "markdown",
   "source": [
    "# MAHN Model"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d9b6c427de781170"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class MAHN(nn.Module):\n",
    "    \"\"\"Predicts disease-metabolite association scores from several graph views.\n",
    "\n",
    "    Three encoders feed the final score:\n",
    "      * GraphConv / SAGEConv over two disease and two metabolite similarity graphs,\n",
    "      * a multi-head GAT over the graph ``G`` for the 'dme' / 'med' metapaths,\n",
    "      * metapath-specific attention heads over edge subgraphs of ``G0`` for 'dmi' and 'mime'.\n",
    "    Per-metapath embeddings are fused by SemanticAttention, concatenated with the\n",
    "    raw similarity features, projected, and decoded by BilinearDecoder.\n",
    "\n",
    "    Args:\n",
    "        G_D0, G_D1: Disease similarity graphs (stored on the module; forward receives its own).\n",
    "        G_ME0, G_ME1: Metabolite similarity graphs (stored likewise).\n",
    "        hidden_dim: Output width of the similarity-graph convolutions.\n",
    "        G: Graph handed to the GAT and metapath layers; forward reads 'd_sim' and 'me_sim' from its ndata.\n",
    "        meta_paths_list: Metapath names; forward recognises 'dme', 'med', 'dmi' and 'mime'.\n",
    "        feature_attn_size: Feature size given to the GAT and metapath layers.\n",
    "        num_heads: Number of attention heads.\n",
    "        num_diseases, num_metabolite: Node counts used to slice the node embeddings.\n",
    "        num_microbe: Stored, but not read anywhere in this class.\n",
    "        d_sim_dim, me_sim_dim: Widths of the 'd_sim' / 'me_sim' node features.\n",
    "        mi_sim_dim: Accepted for interface compatibility; unused in this class.\n",
    "        out_dim: Width of the final node embedding passed to the decoder.\n",
    "        dropout: Dropout probability.\n",
    "        slope: Forwarded to the attention layers (presumably a LeakyReLU slope; confirm in layers.py).\n",
    "        device: Device onto which forward moves graphs and tensors.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, G_D0, G_D1, G_ME0, G_ME1, hidden_dim, G, meta_paths_list, feature_attn_size, num_heads, num_diseases, num_metabolite, num_microbe,\n",
    "                 d_sim_dim, me_sim_dim, mi_sim_dim, out_dim, dropout, slope, device):\n",
    "        super(MAHN, self).__init__()\n",
    "        self.device = device\n",
    "        self.G_D0 = G_D0\n",
    "        self.G_D1 = G_D1\n",
    "        self.G_ME0 = G_ME0\n",
    "        self.G_ME1 = G_ME1\n",
    "        self.G = G\n",
    "        self.meta_paths = meta_paths_list\n",
    "        self.num_heads = num_heads\n",
    "        self.num_diseases = num_diseases\n",
    "        self.num_metabolite = num_metabolite\n",
    "        self.num_microbe = num_microbe\n",
    "        # Submodules are created in the original order so random initialisation is unchanged.\n",
    "        self.conv1 = GraphConv(d_sim_dim, hidden_dim)\n",
    "        self.sageconv = SAGEConv(me_sim_dim, hidden_dim, 'mean')\n",
    "        self.gat = MultiHeadGATLayer(G, feature_attn_size, num_heads, dropout, slope, device, merge='cat')\n",
    "        # Not used by forward; kept so existing attribute access keeps working.\n",
    "        self.heads = nn.ModuleList()\n",
    "        self.metapath_layers = nn.ModuleList()\n",
    "        for _ in range(self.num_heads):\n",
    "            self.metapath_layers.append(HAN_metapath_specific(G, feature_attn_size, out_dim, dropout, slope, device))\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.me_fc = nn.Linear(feature_attn_size * num_heads + me_sim_dim, out_dim * 2)\n",
    "        self.d_fc = nn.Linear(feature_attn_size * num_heads + d_sim_dim, out_dim * 2)\n",
    "        self.semantic_attention = SemanticAttention(in_size=out_dim * num_heads)\n",
    "        # forward feeds h_fc the concatenation of two hidden_dim-wide similarity embeddings\n",
    "        # and one (out_dim * 2)-wide projection; this equals out_dim * 4 when hidden_dim == out_dim.\n",
    "        self.h_fc = nn.Linear(hidden_dim * 2 + out_dim * 2, out_dim)\n",
    "        # Not used by forward; kept so the state_dict layout is unchanged.\n",
    "        self.predict = nn.Linear(out_dim * 2, 1)\n",
    "        # The decoder consumes the out_dim-wide output of h_fc, so it must be sized from out_dim.\n",
    "        self.BilinearDecoder = BilinearDecoder(feature_size=out_dim)\n",
    "\n",
    "    def _encode_similarity(self, graph, feat_key, conv):\n",
    "        \"\"\"Run one similarity graph through ``conv``, then apply ELU and dropout.\"\"\"\n",
    "        graph = graph.to(self.device)\n",
    "        graph.ndata[feat_key] = graph.ndata[feat_key].to(self.device)\n",
    "        hidden = conv(graph, graph.ndata[feat_key]).to(self.device)\n",
    "        return self.dropout(F.elu(hidden)).to(self.device)\n",
    "\n",
    "    def _encode_metapath(self, G0, meta_path):\n",
    "        \"\"\"Embed the nodes of ``G0`` using only the edges whose ``meta_path`` flag is set.\"\"\"\n",
    "        edge_ids = G0.filter_edges(lambda edges: edges.data[meta_path]).to(self.device)\n",
    "        # preserve_nodes=True keeps every node, so row indices still match node ids of G0.\n",
    "        subgraph = G0.edge_subgraph(edge_ids, preserve_nodes=True)\n",
    "        subgraph = subgraph.to(self.device)\n",
    "        head_outs = [attn_head(subgraph, meta_path) for attn_head in self.metapath_layers]\n",
    "        return torch.cat(head_outs, dim=1).to(self.device)\n",
    "\n",
    "    def forward(self, G_D0, G_D1, G_ME0, G_ME1, G, G0, diseases, metabolite):\n",
    "        \"\"\"Score the (disease, metabolite) pairs selected by the two index arguments.\n",
    "\n",
    "        Args:\n",
    "            G_D0, G_D1: Disease similarity graphs carrying ndata['d_sim'].\n",
    "            G_ME0, G_ME1: Metabolite similarity graphs carrying ndata['me_sim'].\n",
    "            G: Graph passed to the GAT layer.\n",
    "            G0: Graph whose edges carry 'dmi' and 'mime' flags used to cut metapath subgraphs.\n",
    "            diseases: Row indices into the combined embedding matrix for the disease side of each pair.\n",
    "            metabolite: Row indices into the same matrix for the metabolite side. Metabolite rows\n",
    "                follow the disease rows, so the caller presumably offsets them by num_diseases (verify).\n",
    "\n",
    "        Returns:\n",
    "            The BilinearDecoder output for the selected pairs.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If meta_paths_list lacks 'dmi', 'mime', or both of 'dme' / 'med'.\n",
    "        \"\"\"\n",
    "        # Similarity-graph views: two per entity type, concatenated feature-wise.\n",
    "        h_D0 = self._encode_similarity(G_D0, 'd_sim', self.conv1)\n",
    "        h_D1 = self._encode_similarity(G_D1, 'd_sim', self.conv1)\n",
    "        h_ME0 = self._encode_similarity(G_ME0, 'me_sim', self.sageconv)\n",
    "        h_ME1 = self._encode_similarity(G_ME1, 'me_sim', self.sageconv)\n",
    "        h_D = torch.cat((h_D0, h_D1), dim=1)\n",
    "        h_ME = torch.cat((h_ME0, h_ME1), dim=1)\n",
    "\n",
    "        h_agg0 = h_agg1 = h_agg2 = None\n",
    "        for meta_path in self.meta_paths:\n",
    "            if meta_path in ('dme', 'med'):\n",
    "                # Both names share one GAT pass; run it only the first time either appears.\n",
    "                if h_agg0 is None:\n",
    "                    h_agg0 = self.gat(G).to(self.device)\n",
    "            elif meta_path == 'dmi':\n",
    "                h_agg1 = self._encode_metapath(G0, meta_path)\n",
    "            elif meta_path == 'mime':\n",
    "                h_agg2 = self._encode_metapath(G0, meta_path)\n",
    "\n",
    "        required = {'dme/med': h_agg0, 'dmi': h_agg1, 'mime': h_agg2}\n",
    "        missing = [name for name, emb in required.items() if emb is None]\n",
    "        if missing:\n",
    "            raise ValueError('meta_paths_list is missing required metapath(s): ' + ', '.join(missing))\n",
    "\n",
    "        # Disease features and metabolite features under the different metapaths.\n",
    "        n_d = self.num_diseases\n",
    "        n_end = n_d + self.num_metabolite\n",
    "        disease0 = h_agg0[:n_d]\n",
    "        metabolite0 = h_agg0[n_d:n_end]\n",
    "        disease1 = h_agg1[:n_d]\n",
    "        metabolite1 = h_agg2[n_d:n_end]\n",
    "\n",
    "        h1 = self.semantic_attention(torch.stack((disease0, disease1), dim=1))\n",
    "        h2 = self.semantic_attention(torch.stack((metabolite0, metabolite1), dim=1))\n",
    "\n",
    "        # Move the stored graph's features to the target device BEFORE concatenating;\n",
    "        # torch.cat cannot mix devices, and self.G is never moved elsewhere in this class.\n",
    "        d_sim = self.G.ndata['d_sim'][:n_d].to(self.device)\n",
    "        me_sim = self.G.ndata['me_sim'][n_d:n_end].to(self.device)\n",
    "        h_d = torch.cat((h1, d_sim), dim=1)\n",
    "        h_me = torch.cat((h2, me_sim), dim=1)\n",
    "\n",
    "        # Metabolite branch first, then disease: keeps the original dropout call order.\n",
    "        h_me = self.dropout(F.elu(self.me_fc(h_me)))\n",
    "        h_d = self.dropout(F.elu(self.d_fc(h_d)))\n",
    "\n",
    "        h_me_final = torch.cat((h_ME, h_me), dim=1)\n",
    "        h_d_final = torch.cat((h_D, h_d), dim=1)\n",
    "\n",
    "        # Stack disease rows on top of metabolite rows and project to out_dim.\n",
    "        h = torch.cat((h_d_final, h_me_final), dim=0)\n",
    "        h = self.dropout(F.elu(self.h_fc(h)))\n",
    "\n",
    "        # Gather the endpoint features of the training or test edges.\n",
    "        h_diseases = h[diseases]\n",
    "        h_metabolite = h[metabolite]\n",
    "        # Decode.\n",
    "        return self.BilinearDecoder(h_diseases, h_metabolite).to(self.device)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:55:20.493698900Z",
     "start_time": "2024-05-08T00:55:20.473700500Z"
    }
   },
   "id": "d73f4a97a3f4f017",
   "execution_count": 3
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Bilinear decoder"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "5b3e027623e90650"
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "class BilinearDecoder(nn.Module):\n",
    "    \"\"\"Scores pairs of embeddings with a learned bilinear form.\n",
    "\n",
    "    Args:\n",
    "        feature_size: Width of both input embeddings; the weight matrix is\n",
    "            feature_size x feature_size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, feature_size):\n",
    "        super().__init__()\n",
    "        # Square weight matrix initialised from a standard normal distribution.\n",
    "        self.W = Parameter(torch.randn(feature_size, feature_size))\n",
    "\n",
    "    def forward(self, h_diseases, h_metabolite):\n",
    "        \"\"\"Return sigmoid(d_i W m_i) for every row pair as a column vector.\n",
    "\n",
    "        Args:\n",
    "            h_diseases: 2-D tensor of disease embeddings, one row per pair.\n",
    "            h_metabolite: Tensor of metabolite embeddings aligned row-wise with h_diseases.\n",
    "\n",
    "        Returns:\n",
    "            Tensor with one score in (0, 1) per row and a trailing axis of size 1.\n",
    "        \"\"\"\n",
    "        projected = h_diseases.mm(self.W)\n",
    "        # A row-wise dot product with the metabolite embedding completes the bilinear form.\n",
    "        logits = (projected * h_metabolite).sum(dim=1, keepdim=True)\n",
    "        return torch.sigmoid(logits)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-05-08T00:55:20.508718500Z",
     "start_time": "2024-05-08T00:55:20.490703300Z"
    }
   },
   "id": "f55165cfd6d56123",
   "execution_count": 4
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
